import numpy as np
import cv2
import time
import kerasModels
import sys

# Location of the saved network weights, relative to the working directory.
WEIGHTS_PATH = 'model/keras_model/cnn_weights.h5'

# Build the network once at import time; the meaning of the 1.0 argument is
# defined by kerasModels.lenet_5.
model = kerasModels.lenet_5(1.0)

# Populate the freshly built network with the weights stored on disk.
model.load_weights(WEIGHTS_PATH)

# model.summary()

# Bundled evaluation images are addressed as data/data2/<digit>.<sample>.jpg.
DATA_PATTERN = 'data/data2/%s.%s.jpg'
NUM_DIGITS = 10   # first placeholder runs over 0 .. NUM_DIGITS - 1
NUM_SAMPLES = 8   # second placeholder runs over 1 .. NUM_SAMPLES


def to_model_input(gray, normalize=False):
    """Convert a 28x28 image into the float32 tensor handed to the model.

    Args:
        gray: array-like holding exactly 28 * 28 values.
        normalize: when true, divide the converted values by 255.0.

    Returns:
        A float32 ``numpy.ndarray`` of shape (1, 1, 28, 28).

    Raises:
        ValueError: if *gray* cannot be reshaped to 28x28.
    """
    # NOTE(review): the multiplication is carried out in the input's own
    # dtype. For the uint8 arrays OpenCV normally returns this wraps modulo
    # 256 instead of scaling. Kept as-is because the stored weights were
    # presumably evaluated with exactly this preprocessing -- confirm before
    # changing.
    pixels = np.array(np.reshape(gray, [28, 28]) * 255, dtype=np.float32)
    if normalize:
        pixels = pixels / 255.0
    return pixels.reshape(-1, 1, 28, 28)


def _load_gray(path):
    """Read *path*, convert it to a single channel and apply a 3x3 Gaussian blur.

    Raises:
        IOError: if OpenCV cannot read the file.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise IOError('cannot read image: %s' % path)
    # NOTE(review): cv2.imread yields BGR channel order, so COLOR_BGR2GRAY
    # would be the matching flag. Left unchanged so the model keeps seeing
    # the same pixel values -- confirm against the training pipeline.
    img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    return cv2.GaussianBlur(img, (3, 3), 0)


def _classify(path, normalize):
    """Run the model on one image file.

    Returns:
        A ``(digit, gray)`` pair: the argmax of the model output and the
        preprocessed image (for display).
    """
    gray = _load_gray(path)
    scores = model.predict(to_model_input(gray, normalize))
    return np.argmax(scores), gray


def _show_prediction(path, wait_ms):
    """Classify *path*, print the result with its latency, and display the image.

    *wait_ms* is passed to ``cv2.waitKey``; 0 blocks until a key is pressed.
    """
    started = time.time()
    # NOTE(review): unlike the batch evaluation, this path feeds values that
    # are not divided by 255.0 to the model. Behaviour is preserved here, but
    # one of the two scales is likely unintended -- confirm which one the
    # weights expect.
    digit, gray = _classify(path, normalize=False)
    print("result: the number is %s, using time: %.3f ms"
          % (digit, (time.time() - started) * 1000))
    cv2.imshow("imput", gray)
    cv2.waitKey(wait_ms)


def _run_batch():
    """Classify every bundled sample and print total time, per-image time and accuracy."""
    total = NUM_DIGITS * NUM_SAMPLES
    hits = 0
    started = time.time()
    for sample in range(1, NUM_SAMPLES + 1):
        for digit in range(NUM_DIGITS):
            predicted, _ = _classify(DATA_PATTERN % (digit, sample), normalize=True)
            print(predicted, ' and ', digit)
            if predicted == digit:
                hits += 1
    # Take a single reading so both reported figures describe the same interval.
    elapsed = time.time() - started
    res = hits / float(total)
    print('use time: %.3f s, average: %.3f ms' % (elapsed, elapsed * 1000.0 / total))
    print('accuracy is %.1f%%' % (res * 100))


def _run_interactive():
    """Read (digit, sample) pairs from stdin and classify the matching bundled image.

    Each round reads two lines, one integer per line. Pairs outside the
    bundled data set are ignored. The loop ends when stdin is exhausted.
    """
    while True:
        try:
            digit, sample = int(input()), int(input())
        except EOFError:
            break
        except ValueError:
            print('expected two integers: digit (0-%d) then sample (1-%d)'
                  % (NUM_DIGITS - 1, NUM_SAMPLES))
            continue
        if 0 <= digit < NUM_DIGITS and 1 <= sample <= NUM_SAMPLES:
            _show_prediction(DATA_PATTERN % (digit, sample), wait_ms=10)


def main():
    """Select the run mode from the command line.

    * no argument: evaluate the whole bundled data set;
    * an argument longer than 5 characters: treat it as an image path and
      classify that single file, then wait for a key press;
    * any shorter argument: interactive mode.
    """
    if len(sys.argv) < 2:
        _run_batch()
    elif len(sys.argv[1]) > 5:
        _show_prediction(sys.argv[1], wait_ms=0)
    else:
        _run_interactive()


if __name__ == '__main__':
    main()